# K-NN (k-nearest neighbours) in Python

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

def k_nearest_neighbor(the_k, dataset, predictd_date):
    """Classify one point by majority vote among its nearest neighbours.

    The Euclidean distance from ``predictd_date`` to every point in
    ``dataset`` is computed, the ``the_k`` closest points are selected, and
    the most frequent group label among them wins.

    Args:
        the_k: Number of neighbours allowed to vote; must be at least 1.
        dataset: Mapping of group label -> iterable of feature vectors.
        predictd_date: Feature vector of the single point to classify. It is
            expected to have the same length as the vectors in ``dataset``.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the winning group and
        ``confidence`` is the fraction of the voting neighbours that belong
        to that group.

    Raises:
        ValueError: If ``the_k`` is smaller than 1 or ``dataset`` holds no
            points at all.
    """
    if the_k < 1:
        raise ValueError("the_k must be a positive integer")

    target = np.asarray(predictd_date)
    # (distance, group) pairs: sorting orders by distance first and falls
    # back to the group label only when two distances are equal.
    distances = [
        (np.linalg.norm(target - np.asarray(data)), group)
        for group, points in dataset.items()
        for data in points
    ]
    if not distances:
        raise ValueError("dataset contains no points")

    distances.sort()
    top_k_nearest = [group for _, group in distances[:the_k]]

    # Count once and unpack both the winning label and its vote count.
    result, votes = Counter(top_k_nearest).most_common(1)[0]
    # Divide by the number of neighbours that actually voted so the share
    # stays correct when the dataset has fewer than the_k points.
    prediction = votes * 1.0 / len(top_k_nearest)

    return result, prediction

if __name__ == '__main__':

    # Two labelled clusters; each group label doubles as a matplotlib colour.
    dataset = {'black': [[1, 2], [2, 3], [3, 1]], 'red': [[6, 5], [7, 7], [8, 6]]}
    new_features = [[6.5, 7.2], [3.1, 3.2]]

    # Plot every training point in the colour of its group.
    Figure = plt.figure()
    Figure_sub = Figure.add_subplot(111)

    for group in dataset:
        for data in dataset[group]:
            Figure_sub.scatter(data[0], data[1], s=50, color=group)

    # Classify each new point on its own: k_nearest_neighbor expects a single
    # feature vector, and each point is drawn at its own (x, y) position in
    # the colour of the predicted group.
    for feature in new_features:
        Result, prediction = k_nearest_neighbor(3, dataset, feature)
        Figure_sub.scatter(feature[0], feature[1], s=60, color=Result)

    plt.show()

# Author: dctrl
#
# Created: 2018-11-08 Thu 19:26
#
# Validate